{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Remove old Weaviate DB files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "# Delete ~/.local/share/weaviate so the notebook starts without old Weaviate DB files.\n",
    "# `rm -rf` is irreversible and silently succeeds when the directory does not exist.\n",
    "!rm -rf ~/.local/share/weaviate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Approximate Nearest Neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "from random import random, randint\n",
    "from math import floor, log\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib as mtplt\n",
    "from matplotlib import pyplot as plt\n",
    "# NOTE(review): star import; presumably the source of nearest_neigbor, construct_HNSW\n",
    "# and search_HNSW used in later cells - verify against utils.py.\n",
    "from utils import *\n",
    "\n",
    "vec_num = 40 # Number of vectors (nodes)\n",
    "dim = 2 # Dimension. Keep at 2: all the graph plots are for dim 2. If changed, comment out the plots.\n",
    "m_nearest_neighbor = 2 # M nearest neighbors used in construction of the Navigable Small World (NSW)\n",
    "\n",
    "# vec_num random points with dim coordinates each, sampled uniformly from [0, 1).\n",
    "# No random seed is set here, so the points (and every plot below) differ between runs.\n",
    "vec_pos = np.random.uniform(size=(vec_num, dim))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query Vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "# The query point; it is drawn as node \"Q\" in the plots below\n",
    "query_vec = [0.5, 0.5]\n",
    "\n",
    "# A single-node graph that carries the query position as its 'pos' attribute\n",
    "nodes = [(\"Q\", {\"pos\": query_vec})]\n",
    "G_query = nx.Graph()\n",
    "G_query.add_nodes_from(nodes)\n",
    "\n",
    "print(\"nodes = \", nodes, flush=True)\n",
    "\n",
    "# Mapping node -> position, in the form nx.draw expects\n",
    "pos_query = nx.get_node_attributes(G_query, 'pos')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Brute Force"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "# NOTE(review): G_lin / G_best are presumably the graph of all vectors and the\n",
    "# nearest neighbor found by the linear scan - confirm in utils.nearest_neigbor.\n",
    "G_lin, G_best = nearest_neigbor(vec_pos, query_vec)\n",
    "\n",
    "pos_lin = nx.get_node_attributes(G_lin, 'pos')\n",
    "pos_best = nx.get_node_attributes(G_best, 'pos')\n",
    "\n",
    "fig, axs = plt.subplots()\n",
    "\n",
    "# Layer 1: G_lin in light blue, edges hidden (width=0.0)\n",
    "nx.draw(\n",
    "    G_lin, pos_lin, ax=axs, with_labels=True,\n",
    "    node_size=150, node_color=[[0.8, 0.8, 1]],\n",
    "    width=0.0, font_size=7,\n",
    ")\n",
    "# Layer 2: the query node in dark red with a white label\n",
    "nx.draw(\n",
    "    G_query, pos_query, ax=axs, with_labels=True,\n",
    "    node_size=200, node_color=[[0.5, 0, 0]], font_color='white',\n",
    "    width=0.5, font_size=7, font_weight='bold',\n",
    ")\n",
    "# Layer 3: G_best in gold, drawn last so it stays on top\n",
    "nx.draw(\n",
    "    G_best, pos_best, ax=axs, with_labels=True,\n",
    "    node_size=200, node_color=[[0.85, 0.7, 0.2]],\n",
    "    width=0.5, font_size=7, font_weight='bold',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HNSW Construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 301
   },
   "outputs": [],
   "source": [
    "GraphArray = construct_HNSW(vec_pos, m_nearest_neighbor)\n",
    "\n",
    "# One figure per layer, from the highest layer index down to layer 0\n",
    "for layer_i in reversed(range(len(GraphArray))):\n",
    "    fig, axs = plt.subplots()\n",
    "\n",
    "    print(\"layer_i = \", layer_i)\n",
    "\n",
    "    if layer_i > 0:\n",
    "        # Faint backdrop: GraphArray[0] with hidden edges (width=0.0) and grey labels\n",
    "        pos_layer_0 = nx.get_node_attributes(GraphArray[0], 'pos')\n",
    "        nx.draw(\n",
    "            GraphArray[0], pos_layer_0, ax=axs, with_labels=True,\n",
    "            node_size=120, node_color=[[0.9, 0.9, 1]], width=0.0,\n",
    "            font_size=6, font_color=(0.65, 0.65, 0.65),\n",
    "        )\n",
    "\n",
    "    # The current layer itself\n",
    "    layer_graph = GraphArray[layer_i]\n",
    "    pos_layer_i = nx.get_node_attributes(layer_graph, 'pos')\n",
    "    nx.draw(\n",
    "        layer_graph, pos_layer_i, ax=axs, with_labels=True,\n",
    "        node_size=150, node_color=[[0.7, 0.7, 1]], width=0.5, font_size=7,\n",
    "    )\n",
    "    # Query node (red) and G_best from the brute-force cell (gold) on top\n",
    "    nx.draw(\n",
    "        G_query, pos_query, ax=axs, with_labels=True,\n",
    "        node_size=200, node_color=[[0.8, 0, 0]],\n",
    "        width=0.5, font_size=7, font_weight='bold',\n",
    "    )\n",
    "    nx.draw(\n",
    "        G_best, pos_best, ax=axs, with_labels=True,\n",
    "        node_size=200, node_color=[[0.85, 0.7, 0.2]],\n",
    "        width=0.5, font_size=7, font_weight='bold',\n",
    "    )\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HNSW Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 403
   },
   "outputs": [],
   "source": [
    "SearchPathGraphArray, EntryGraphArray = search_HNSW(GraphArray, G_query)\n",
    "\n",
    "# One figure per layer, from the highest layer index down to layer 0\n",
    "for layer_i in reversed(range(len(GraphArray))):\n",
    "    fig, axs = plt.subplots()\n",
    "\n",
    "    print(\"layer_i = \", layer_i)\n",
    "    G_path_layer = SearchPathGraphArray[layer_i]\n",
    "    pos_path = nx.get_node_attributes(G_path_layer, 'pos')\n",
    "    G_entry = EntryGraphArray[layer_i]\n",
    "    pos_entry = nx.get_node_attributes(G_entry, 'pos')\n",
    "\n",
    "    if layer_i > 0:\n",
    "        # Faint backdrop: GraphArray[0] with hidden edges (width=0.0) and grey labels\n",
    "        pos_layer_0 = nx.get_node_attributes(GraphArray[0], 'pos')\n",
    "        nx.draw(\n",
    "            GraphArray[0], pos_layer_0, ax=axs, with_labels=True,\n",
    "            node_size=120, node_color=[[0.9, 0.9, 1]], width=0.0,\n",
    "            font_size=6, font_color=(0.65, 0.65, 0.65),\n",
    "        )\n",
    "\n",
    "    pos_layer_i = nx.get_node_attributes(GraphArray[layer_i], 'pos')\n",
    "\n",
    "    # (graph, positions, node_size, node_color, font_size); later rows are drawn on top\n",
    "    overlays = [\n",
    "        (GraphArray[layer_i], pos_layer_i, 100, [[0.7, 0.7, 1]], 6),  # current layer\n",
    "        (G_path_layer, pos_path, 110, [[0.8, 1, 0.8]], 6),            # search path graph for this layer\n",
    "        (G_query, pos_query, 80, [[0.8, 0, 0]], 7),                   # query node\n",
    "        (G_best, pos_best, 70, [[0.85, 0.7, 0.2]], 7),                # G_best from the brute-force cell\n",
    "        (G_entry, pos_entry, 80, [[0.1, 0.9, 0.1]], 7),               # entry graph for this layer\n",
    "    ]\n",
    "    for graph, positions, size, color, label_size in overlays:\n",
    "        nx.draw(\n",
    "            graph, positions, ax=axs, with_labels=True,\n",
    "            node_size=size, node_color=color, width=0.5, font_size=label_size,\n",
    "        )\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pure Vector Search - with a vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import weaviate\n",
    "from weaviate import EmbeddedOptions\n",
    "\n",
    "# Connect through an embedded Weaviate instance using the default EmbeddedOptions\n",
    "embedded_options = EmbeddedOptions()\n",
    "client = weaviate.Client(embedded_options=embedded_options)\n",
    "\n",
    "# Shown as the cell output\n",
    "client.is_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 284
   },
   "outputs": [],
   "source": [
    "# To reset the schema, uncomment the two lines below.\n",
    "# CAUTION: This will delete your collection\n",
    "# if client.schema.exists(\"MyCollection\"):\n",
    "#     client.schema.delete_class(\"MyCollection\")\n",
    "\n",
    "# let's use cosine distance\n",
    "vector_index_config = {\"distance\": \"cosine\"}\n",
    "\n",
    "# \"vectorizer\": \"none\" - the vectors are supplied explicitly in the import cell\n",
    "schema = {\n",
    "    \"class\": \"MyCollection\",\n",
    "    \"vectorizer\": \"none\",\n",
    "    \"vectorIndexConfig\": vector_index_config,\n",
    "}\n",
    "\n",
    "client.schema.create_class(schema)\n",
    "\n",
    "print(\"Successfully created the schema.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 471
   },
   "outputs": [],
   "source": [
    "# Five sample objects: two properties (title, foo) plus a 6-dimensional vector each\n",
    "data = [\n",
    "    {\"title\": \"First Object\", \"foo\": 99, \"vector\": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]},\n",
    "    {\"title\": \"Second Object\", \"foo\": 77, \"vector\": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]},\n",
    "    {\"title\": \"Third Object\", \"foo\": 55, \"vector\": [0.3, 0.1, -0.1, -0.3, -0.5, -0.7]},\n",
    "    {\"title\": \"Fourth Object\", \"foo\": 33, \"vector\": [0.4, 0.41, 0.42, 0.43, 0.44, 0.45]},\n",
    "    {\"title\": \"Fifth Object\", \"foo\": 11, \"vector\": [0.5, 0.5, 0, 0, 0, 0]},\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "client.batch.configure(batch_size=10)  # Configure batch\n",
    "\n",
    "# Batch import all objects.\n",
    "# Yes, batch is overkill for 5 objects, but it is recommended for large volumes of data.\n",
    "# The `with` block hands back the batch handle; queue objects through it rather than\n",
    "# reaching for client.batch again.\n",
    "with client.batch as batch:\n",
    "    for item in data:\n",
    "        # Only title and foo are stored as properties; the vector is passed separately\n",
    "        properties = {\n",
    "            \"title\": item[\"title\"],\n",
    "            \"foo\": item[\"foo\"],\n",
    "        }\n",
    "\n",
    "        # the call that performs data insert\n",
    "        batch.add_data_object(\n",
    "            class_name=\"MyCollection\",\n",
    "            data_object=properties,\n",
    "            vector=item[\"vector\"],  # your vector embeddings go here\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# Check number of objects: aggregate over MyCollection and request the meta count\n",
    "count_query = client.query.aggregate(\"MyCollection\").with_meta_count()\n",
    "response = count_query.do()\n",
    "\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query Weaviate: Vector Search (vector embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "# Vector to search with (6 values, like the vectors imported above)\n",
    "query_embedding = [-0.012, 0.021, -0.23, -0.42, 0.5, 0.5]\n",
    "\n",
    "# Build the query step by step: fetch titles, search near the vector, cap the results\n",
    "title_query = client.query.get(\"MyCollection\", [\"title\"])\n",
    "title_query = title_query.with_near_vector({\"vector\": query_embedding})\n",
    "title_query = title_query.with_limit(2)  # limit the output to only 2\n",
    "response = title_query.do()\n",
    "\n",
    "result = response[\"data\"][\"Get\"][\"MyCollection\"]\n",
    "print(json.dumps(result, indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "# Same vector search as above, now also returning _additional fields for each hit\n",
    "response = (\n",
    "    client.query\n",
    "    .get(\"MyCollection\", [\"title\"])\n",
    "    .with_near_vector({\n",
    "        \"vector\": [-0.012, 0.021, -0.23, -0.42, 0.5, 0.5]\n",
    "    })\n",
    "    .with_limit(2) # limit the output to only 2\n",
    "    # One property name per list element (not \"vector, id\" packed into one string)\n",
    "    .with_additional([\"distance\", \"vector\", \"id\"])\n",
    "    .do()\n",
    ")\n",
    "\n",
    "result = response[\"data\"][\"Get\"][\"MyCollection\"]\n",
    "print(json.dumps(result, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vector Search with filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 335
   },
   "outputs": [],
   "source": [
    "response = (\n",
    "    client.query\n",
    "    .get(\"MyCollection\", [\"title\", \"foo\"])\n",
    "    .with_near_vector({\n",
    "        \"vector\": [-0.012, 0.021, -0.23, -0.42, 0.5, 0.5]\n",
    "    })\n",
    "    # output the distance of the query vector to each object, plus the object's id;\n",
    "    # one property name per list element (not \"distance, id\" packed into one string)\n",
    "    .with_additional([\"distance\", \"id\"])\n",
    "    # keep only objects whose foo is greater than 44\n",
    "    # NOTE(review): valueNumber presumably matches the type inferred for foo at import - confirm\n",
    "    .with_where({\n",
    "        \"path\": [\"foo\"],\n",
    "        \"operator\": \"GreaterThan\",\n",
    "        \"valueNumber\": 44\n",
    "    })\n",
    "    .with_limit(2) # limit the output to only 2\n",
    "    .do()\n",
    ")\n",
    "\n",
    "# The nearObject cell below reads result[0]['_additional']['id'] from this result\n",
    "result = response[\"data\"][\"Get\"][\"MyCollection\"]\n",
    "print(json.dumps(result, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### nearObject Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 233
   },
   "outputs": [],
   "source": [
    "# The id of the search object: the first hit of the previous (filtered) query\n",
    "search_object_id = result[0]['_additional']['id']\n",
    "\n",
    "response = (\n",
    "    client.query\n",
    "    .get(\"MyCollection\", [\"title\"])\n",
    "    .with_near_object({\n",
    "        \"id\": search_object_id\n",
    "    })\n",
    "    .with_limit(3)\n",
    "    .with_additional([\"distance\"])\n",
    "    .do()\n",
    ")\n",
    "\n",
    "# Stored under its own name: this response carries no 'id', so overwriting `result`\n",
    "# would make this cell raise KeyError when it is run a second time.\n",
    "near_object_result = response[\"data\"][\"Get\"][\"MyCollection\"]\n",
    "print(json.dumps(near_object_result, indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
